from pyparsing import White
import Game
import model
import numpy as np
import random
import os
from tqdm import tqdm


def AlphaBeta(layers, max_status, ChessList, max_value, min_value, min=-float('inf'), max=float('inf')):
    """Depth-limited minimax search with alpha-beta pruning.

    Args:
        layers: Remaining search depth. At 0 the position is scored with
            ``GetValue`` and no move is returned.
        max_status: Truthy when the side to move maximises the score,
            falsy when it minimises.
        ChessList: Board indexed by ``(x, y)`` tuples where 0 marks an empty
            cell. Trial stones are written into it during the search and
            reset to 0 again before the function returns normally.
        max_value: Stone value placed for the maximising side.
        min_value: Stone value placed for the minimising side.
        min: Current lower bound (alpha). The name shadows the builtin inside
            this function; it is kept so existing callers keep working.
        max: Current upper bound (beta). Shadows the builtin as well.

    Returns:
        A ``(value, point)`` tuple. ``point`` is the move that improved the
        bound last, or ``None`` at depth 0, when no empty cell exists, or
        when no candidate improved on the incoming bound.
    """
    if layers == 0:
        return GetValue(ChessList), None

    rem_point = None
    has_move = False
    if max_status:  # maximising side
        for point in GetNext(ChessList):
            has_move = True
            ChessList[point] = max_value
            value, _ = AlphaBeta(
                layers - 1, not max_status, ChessList, max_value, min_value, min, max)
            ChessList[point] = 0  # undo the trial move
            if value > min:
                rem_point = point
                min = value
            if min >= max:
                # The minimising parent already has a better option: prune.
                break
        bound = min
    else:  # minimising side
        for point in GetNext(ChessList):
            has_move = True
            ChessList[point] = min_value
            value, _ = AlphaBeta(
                layers - 1, not max_status, ChessList, max_value, min_value, min, max)
            ChessList[point] = 0  # undo the trial move
            if value < max:
                rem_point = point
                max = value
            if min >= max:
                # The maximising parent already has a better option: prune.
                break
        bound = max

    if not has_move:
        # No empty cell left: score the position itself instead of leaking
        # the untouched alpha/beta bound (possibly +/-inf) to the caller.
        return GetValue(ChessList), None
    return bound, rem_point


def GetValue(ChessList):
    """Return the score ``model.Pre`` assigns to the given board."""
    score = model.Pre(ChessList)
    return score


def GetNext(ChessList):
    """Yield the ``(x, y)`` index of every empty (== 0) cell of the board.

    The row order and the column order are each shuffled once per call, so
    the cells are visited in a randomised order: rows in the outer loop,
    columns in the inner loop.
    """
    n_rows, n_cols = ChessList.shape
    row_order = list(range(n_rows))
    col_order = list(range(n_cols))
    # Rows are shuffled before columns to keep the random stream unchanged.
    random.shuffle(row_order)
    random.shuffle(col_order)
    for x in row_order:
        for y in col_order:
            if ChessList[x, y] != 0:
                continue
            yield (x, y)


def get_chess(ChessList, Role, layers=3):
    """Pick the next move for ``Role`` with an alpha-beta search.

    ``Role`` is searched as the maximising side when it equals
    ``Game.white`` and as the minimising side otherwise; the opponent's
    stones are written as ``-Role``. Returns the move found by
    ``AlphaBeta`` (its second result), which may be ``None``.
    """
    is_maximising = Role == Game.white
    result = AlphaBeta(layers, is_maximising, ChessList, Role, -Role)
    return result[1]


def autoBattle(randn):
    """Play one self-play game and record every position that was reached.

    Each turn the move comes from ``get_chess``; with probability
    ``1 / randn`` it is replaced by a uniformly random empty cell.

    Args:
        randn: Upper bound passed to ``random.randint(1, randn)``; must be
            at least 1. Larger values make random moves rarer.

    Returns:
        A ``(positions, score)`` tuple. ``positions`` is a list of board
        copies, one taken after each move. ``score`` is
        ``Role * (Game.board_size ** 2 - moves_played) ** 2``, where ``Role``
        is the side that moved last when ``Game.judgeResult`` returned a
        truthy value, or 0 when ``get_chess`` had no move to offer.
    """
    rem_ChessList = []
    ChessList = Game.initChessSquare()
    Role = Game.white
    count = 0
    while True:
        point = get_chess(ChessList, Role)
        if point is None:
            # The search produced no move: the game is scored as 0.
            Role = 0
            break
        if random.randint(1, randn) == 1:
            # Exploration: override the searched move with a random empty cell.
            temp = np.argwhere(ChessList == 0)
            point = temp[random.randint(0, temp.shape[0] - 1)]
            point = (point[0], point[1])
        count += 1
        ChessList[point] = Role
        rem_ChessList.append(ChessList.copy())
        if Game.judgeResult(point, Role, ChessList):
            # presumably a truthy result means this move decided the game
            break
        Role = -Role
    return rem_ChessList, Role * ((Game.board_size * Game.board_size - count) ** 2)


def Study(steps, max=250):
    """Run self-play games and train the model on a sliding window of positions.

    Previously saved samples are loaded from ``rem_X.npy`` / ``rem_y.npy``
    when those files exist. After every game the new positions, together
    with their negated copies, are appended to the sample buffer. Once the
    buffer holds more than ``max`` samples it is cut down to the newest
    ``max``, the model is trained on it, and both the buffer and the model
    are saved.

    Args:
        steps: Number of self-play games to run.
        max: Initial buffer size that triggers trimming and training. It
            grows after every training round. The name shadows the builtin
            inside this function; it is kept so existing callers keep working.
    """
    if os.path.exists('rem_X.npy'):
        rem_X = np.load('rem_X.npy')
    else:
        rem_X = np.array([])
    if os.path.exists('rem_y.npy'):
        rem_y = np.load('rem_y.npy')
    else:
        rem_y = np.array([])

    model.Load()
    for i in tqdm(range(steps)):
        # i + 5 makes random moves in autoBattle rarer as training goes on.
        rem_ChessList, Count = autoBattle(i + 5)
        if Count == 0:
            # autoBattle returns 0 when no move was available or the board
            # was filled completely. The original line here was the no-op
            # comparison `Count == -1`; the intended assignment is restored.
            Count = -1
        X = np.vstack(rem_ChessList).reshape(-1, 1,
                                             Game.board_size, Game.board_size)
        y = np.ones((X.shape[0], 1), np.float32) * Count
        # Augment with the negated boards, labelled with the negated score.
        X = np.concatenate((X, -X), axis=0)
        y = np.concatenate((y, -y), axis=0)

        if rem_X.shape == (0,):
            rem_X = X
        else:
            rem_X = np.concatenate((rem_X, X), axis=0)

        if rem_y.shape == (0,):
            rem_y = y
        else:
            rem_y = np.concatenate((rem_y, y), axis=0)

        if len(rem_X) > max:
            # Keep only the newest `max` samples.
            rem_X = rem_X[len(rem_X) - max:]
            rem_y = rem_y[len(rem_y) - max:]
            # Let the window grow by a third of the latest batch size.
            max += int(len(X) / 3)
            model.Train(rem_X, rem_y, 200)
            np.save('rem_X.npy', rem_X)
            np.save('rem_y.npy', rem_y)
            model.Save()


# Script entry point: run 900 self-play/training steps when executed directly.
if __name__ == '__main__':
    Study(900)
